Amazon SageMakerでRを使った独自コンテナを作成してみる

Amazon SageMakerでRを使った独自コンテナを作成してみる

Clock Icon2019.11.18

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

こんにちは、小澤です。

SageMakerを利用したい! けど、ビルドインの手法には含まれてないし、既存のライブラリ・フレームワークを使っても実現できない! となった場合、それ用のコンテナを作成してECRに登録すれば利用可能です。

その際利用する言語としては「SageMakerだしPythonかな?かな?」となりがちですが、 SageMakerから利用可能にするためにいくつか仕組みさえ提供していれば、実は言語は何でもいいんです。

今回はそんな仕組みを見ていくためにRを使って独自コンテナを作成してみます。

SageMakerの独自コンテナとDockerの中に入れておくもの

さて、Pythonを使って独自コンテナを作成する方法に関しては以下の情報を参照してください。

SageMakerの独自コンテナはECRに登録されたDockerコンテナを取得して利用するわけですが、 このコンテナの作り方に関して決まりごとはあまり多くありません。

  • docker run実行時に引数として学習時にはtrain、エンドポイントにはserveが渡される
  • trainが渡されたときには学習処理を実行する
  • ハイパーパラメータのJSONファイルや学習データは決まったパスに置かれる
  • serveが渡されたときには /ping/invocations にアクセスできるようにする

そのため、

  • ファイル入出力が可能である
  • Web API化が可能である
  • Dockerfile内記述のみでインストール可能である

を満たしている言語であれば、Pythonじゃなくても問題ありません。 必要な処理が実装できればあとは、それの実行をDockerfile内のENTRYPOINTで指定しておくだけです。

Rで実装してみる

Rで実装するに際しては、2つの.Rファイルを用意しています。 1つ目は、ENTRYPOINTとして実行されるファイルです。 また、学習時の処理もこのファイル内に記述してます。

2つ目は推論エンドポイントの処理を記述したものです。 API化する部分についてはplumberパッケージを利用していますので、実行時に指定するファイルに関しては別途用意しています。

trainかserveか

コンテナが実行されると、ENTRYPOINTで指定した処理が実行されます。 Rで実装する際には、スクリプトを記述したファイルを実行することになるため、 まずは、実行すべきは学習用の処理か推論エンドポイントの作成かを実行時に渡された引数から判断する処理を実装します。

args <- commandArgs()
if (any(grepl("train", args))) {
  train()
} else if (any(grepl("serve", args))) {
  serve()
}

学習時にはtrain関数、エンドポイント作成時にはserve関数を呼び出すようにしています。 train関数は学習処理の具体的な内容を記載しています。

train <- function() {
  # ここに学習処理を記述している...
}

推論エンドポイント作成時にはplumberの機能を利用するため必要な処理を記述したスクリプトを指定して実行しています。

serve <- function() {
  app <- plumb(paste(prefix, "plumber.R", sep="/"))
  app$run(host='0.0.0.0', port=8080)
}

学習処理の実装

続いて、trainの中身である学習処理の実装をします。 まずは、SageMakerの決まり事である、各種ファイルのパスを変数として定義しています。

prefix <- "/opt/ml"
input_path <- paste(prefix, "input/data", sep="/")
output_path <- paste(prefix, "output", sep="/")
model_path <- paste(prefix, "model", sep="/")
param_path <- paste(prefix, "input/config/hyperparameters.json", sep="/")

channel_name = "train"
training_path <- paste(input_path, channel_name, sep="/")

続いて、学習用の関数として定義したtrainの中身を実装します。

train <- function() {
  # ハイパーパラメータはJSONとしてファイルで渡されるので読み込んで取得する
  # 今回は
  #    必須 : 目的変数となる列名
  #    任意(デフォルト値あり) : randomForestのmtry
  # を受け取る
  training_params <- read_json(param_path)
  target <- training_params$target
  formula <- as.formula(paste(target, ".", sep="~"))
  if (!is.null(training_params$mtry)) {
    mtry <- as.numeric(training_params$mtry)}
  else {
    mtry <- 2
  }

  # 学習対象データの取得
  training_files = list.files(path=training_path, full.names=TRUE)
  training_data = do.call(rbind, lapply(training_files, read.csv))

  model <- randomForest(formula, data=training_data)
  # print出力はCloudWatch Logに吐き出される
  print(model)

  # 推論処理のために目的変数以外の列名を取得
  predict.names <- names(training_data)[-grep(target, names(training_data))]

  # save関数を使ってモデル出力パスに1つのファイルとして出力
  save(model, predict.names, file=paste(model_path, "rf_model.data", sep="/"))
  write("success", file=paste(output_path, "success", sep='/'))
}

学習処理実行時にはこの関数が呼び出されたのち、プログラムが終了するので学習用コンテナは自動的に終了します。

推論処理の実装

続いて、推論処理の実装を見ていきます。

こちらは、plumberパッケージのplumb関数で指定したファイル内の処理を記載します。

load("/opt/ml/model/rf_model.data")

## SageMakerがコンテナの死活監視をするために定期的に呼び出す処理
## アクセスできることが確認できればいいので返す値は何でもいい
#' Ping to show server is there
#' @get /ping
function() {
  return("")
}

## 推論エンドポイントとしてアクセスされる
## 推論を行ってその結果を返す
#' Parse input and return prediction from model
#' @param req The http request sent
#' @post /invocations
function(req) {
  # 入力は文字列して与えられるので、textConnectionを使ってdata frameに変換する
  # ここではヘッダなしのcsv形式で渡されることを想定
  conn <- textConnection(gsub("\\\\n", "\n", req$postBody))
  data <- read.csv(conn, header = FALSE)
  close(conn)

  # 保存しておいた目的変数以外の列名を受け取ったデータに付与
  names(data) <- predict.names

  as.character(predict(model, newdata = data))
}

plumberのrunは制御を返さない処理なので、そのまま動き続けます。 そのため、処理が終了してコンテナが自動的に削除されることもありません。

Dockerファイルの作成

Rを使った独自コンテナ内の処理の実装は以上となります。 今回行っている処理は

  • ENTRYPOINTで指定したスクリプトを実行し、 docker runの引数がtrainかserveで分岐
  • trainの場合学習処理の実行
  • serveの場合 /ping/invocations でHTTPアクセス可能な状態にする

のみです。 この動きさえ実装していれば、言語は問わないこともご理解いただけるかと思います。

さて、最後にDockerfileも確認しておきましょう。

FROM ubuntu:16.04

MAINTAINER <<あなたの情報をここに入れてね>>

RUN apt-get -y update && apt-get install -y --no-install-recommends \
wget \
r-base \
r-base-dev \
ca-certificates

RUN R -e "install.packages(c('randomForest', 'plumber'), repos='https://cloud.r-project.org')"

COPY random_forest.R /opt/ml/random_forest.R
COPY plumber.R /opt/ml/plumber.R

ENTRYPOINT ["/usr/bin/Rscript", "/opt/ml/random_forest.R", "--no-save"]

何の変哲もないubuntuに対して以下を行っています。

  • aptでRのインストール
  • Rの処理を実行して利用するパッケージのインストール
  • 処理を記述したファイルを入れる
  • 実行するファイルの指定

こちらからも各種インストールやENTRYPOINTでの指定が可能な言語であれば何を使っても問題ないことを確認できます。

登録したコンテナを使ってみる

最後にこのコンテナを使ってみましょう。 とはいえ、使い方は通常と同様なので、学習および推論の部分だけ確認します。 せっかくなので、こちらもRで実装しています。

# 登録したECRのコンテナを指定する
# その他の処理はビルドインの時と同様
container <- "<your aws accountid>.dkr.ecr.ap-northeast-1.amazonaws.com/sagemaker-r-randomforest:latest"
estimator <- sagemaker$estimator$Estimator(image_name = container,
                                           role = role,
                                           train_instance_count = 1L,
                                           train_instance_type = 'ml.m5.large',
                                           train_volume_size = 30L,
                                           train_max_run = 3600L,
                                           input_mode = 'File',
                                           output_path = s3_output,
                                           sagemaker_session = session)
estimator$set_hyperparameters(target = "Species")
estimator$fit(inputs = list("train"=s3_train))

# 推論エンドポイント作成もそのまま利用可能
# 作成後も同様に同じように使えるため割愛
model_endpoint <- estimator$deploy(initial_instance_count = 1L,
                                   instance_type = 'ml.t2.medium')

おわりに

今回は、SageMakerでRで実装してみました。

「SageMakerだからPythonで実装かな?かな?」と思われてた方も、 「あ、他の言語でもいいんだ!」と思っていただけたかな?かな?

ライブラリの充実度などの都合で使うかどうかはまた別な話ですが、速さを求めてC/C++を検討する方には朗報なのでしょうか?(私にはそこまでできないのでわからないですが...)

あと、どうでもいいですが「かな?」を2回続けるのがひそかなマイブームです。

Share this article

facebook logohatena logotwitter logo

© Classmethod, Inc. All rights reserved.